import numpy as np
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.memory import ConversationSummaryMemory
from langchain_community.document_loaders.pdf import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter

from ChatGLM_new import xinghuo_llm
from langchain_community.document_loaders import WebBaseLoader
from langchain.document_loaders import TextLoader
from langchain_community.embeddings import SparkLLMTextEmbeddings

embeddings = SparkLLMTextEmbeddings(
    spark_app_id="849d9562",
    spark_api_key="e33637f01cebee4a4675eb479d77df12",
    spark_api_secret="ZmI0NGE5ZGUwNzQwMGZjYWFiYjg4MTIx",
)


print("############pdf文档读入##############")
# loader_pdf = PyPDFLoader(r"D:\05work\01telchina\04智能灯杆\005灯杆2.0\安全红线2.0.pdf")
# docs_pdf = loader_pdf.load()
# print(type(docs_pdf))
# print(docs_pdf[0].page_content)

print("############TXT文档的读入##############")
loader_txt = TextLoader(r'D:\03study\book\python\python_data_course-main\大模型产品开发导论\云岚宗.txt', encoding='utf8')
docs_txt = loader_txt.load()
#print(docs_txt[0].metadata,docs_txt[0].page_content[0:100])

print("############网页文档的读入##############")
# WEB_URL = "https://news.ifeng.com/c/8Y3TlIcTsj0"
# loader_html = WebBaseLoader(WEB_URL)
# docs_html = loader_html.load()
# print(docs_html[0].metadata,docs_html[0].page_content[0:100])

print("############文档的分割##############")
''' 
RecursiveCharacterTextSplitter 递归字符文本分割
RecursiveCharacterTextSplitter 将按不同的字符递归地分割(按照这个优先级["\n\n", "\n", " ", ""])，
    这样就能尽量把所有和语义相关的内容尽可能长时间地保留在同一位置
RecursiveCharacterTextSplitter需要关注的是4个参数：

* separators - 分隔符字符串数组
* chunk_size - 每个文档的字符数量限制
* chunk_overlap - 两份文档重叠区域的长度
* length_function - 长度计算函数
'''
# 导入递归字符文本分割器
text_splitter_txt = RecursiveCharacterTextSplitter(chunk_size = 22500, chunk_overlap = 0, separators=["\n\n", "\n", " ", "", "。", "，"])
# 导入文本
documents_txt = text_splitter_txt.split_documents(docs_txt)
# print(len(documents_txt))
# print(documents_txt[0].page_content)


print("############向量化##############")
query1 = "狗"
query2 = "猫"
query3 = "雨"

# 通过对应的 embedding 类生成 query 的 embedding。
# emb1 = embeddings.embed_query(query1)
# emb2 = embeddings.embed_query(query2)
# emb3 = embeddings.embed_query(query3)
# # 将返回结果转成 numpy 的格式，便于后续计算
# emb1 = np.array(emb1)
# emb2 = np.array(emb2)
# emb3 = np.array(emb3)
#
# print(np.dot(emb1, emb2))
# print(np.dot(emb3, emb2))
# print(np.dot(emb1, emb3))

print("##############向量数据库（Chroma）#######################")
# persist_directory允许我们将目录保存到磁盘上
# 注意from_documents每次运行都是把数据添加进去
# vectordb = Chroma.from_documents(documents=documents_txt, embedding=embeddings, persist_directory="./" )
# vectordb.persist()

vectordb_load = Chroma(
    persist_directory="./",
    embedding_function=embeddings
)
# print(vectordb_load._collection.count())
# print(vectordb_load.similarity_search("轻摆了摆手"))

print("##############构造检索式问答链#######################")
# 创建提示词模板
prompt = ChatPromptTemplate.from_template("""使用下面的语料来回答本模板最末尾的问题。如果你不知道问题的答案，直接回答 "我不知道"，禁止随意编造答案。
        为了保证答案尽可能简洁，你的回答必须不超过三句话，你的回答中不可以带有星号。
        请注意！在每次回答结束之后，你都必须接上 "感谢你的提问" 作为结束语
        以下是一对问题和答案的样例：
            请问：秦始皇的原名是什么
            秦始皇原名嬴政。感谢你的提问。

        以下是语料：
<context>
{context}
</context>

Question: {input}""")
# 创建检索链
document_chain = create_stuff_documents_chain(xinghuo_llm, prompt)

retriever = vectordb_load.as_retriever()
retrieval_chain = create_retrieval_chain(retriever, document_chain)

# 先检索向量数据库将内容填充到提示词模板中，再调用 LLM 生成答案
# response = retrieval_chain.invoke({
#     "input": "萧炎是怎样的强者?"
# })
# print(response["answer"])

print("###########构建一个检索式文档对话模型######################")

memory = ConversationSummaryMemory(
    llm=xinghuo_llm, memory_key="chat_history", return_messages=True
)
qa = ConversationalRetrievalChain.from_llm(verbose=True, llm=xinghuo_llm, retriever=retriever, memory=memory)

res = qa.invoke(
    {"question": "萧炎是谁"}
)
print(res["answer"])

res = qa.invoke(
    {"question": "他结婚了吗？"}
)
print(res["answer"])